library(conText)
library(tidyverse)
library(ggthemes)
library(tm)
library(tidylog)

# Inputs ----
# All three inputs are R objects read with readRDS(); the paths are relative
# to the working directory the script is run from.
input_files <- c(
  # corpus (columns `tweet` and `postgs` are used below)
  corpus = "data/analysis/MPtweets_corpus.rds",
  # (GloVe) pre-trained embeddings
  glove = "data/wordembeddings/glove.rds",
  # transformation matrix
  transform = "data/wordembeddings/khodakA.rds"
)

twts_corpus <- readRDS(input_files[["corpus"]])
pre_trained <- readRDS(input_files[["glove"]])
transform_matrix <- readRDS(input_files[["transform"]])

#---------------------------------
# build context corpus
#---------------------------------

target <- "climate"

# Return the contexts (up to `window` tokens either side) of `target` in the
# tweets of one period, where `period` is the value of `twts_corpus$postgs`
# to keep (0L = pre-GS, 1L = post-GS).
# `%in%` is used instead of `==` so that tweets with an NA `postgs` are
# dropped rather than turned into NA texts by logical subsetting.
get_period_contexts <- function(period, window = 6) {
  in_period <- twts_corpus$postgs %in% period
  get_context(x = twts_corpus$tweet[in_period],
              target = target,
              window = window, valuetype = "fixed", case_insensitive = TRUE,
              hard_cut = FALSE, verbose = TRUE)
}

# find contexts for pre-GS and post-GS
contextpreGS <- get_period_contexts(0L)
contextpostGS <- get_period_contexts(1L)

# bind contexts, tagging each row with its period in a trailing `GS` column;
# mutate() (unlike cbind(df, GS = ...)) also works when a period has no
# matching contexts
contexts_corpus <- rbind(mutate(contextpreGS, GS = "pre"),
                         mutate(contextpostGS, GS = "post"))

#---------------------------------
# get local vocab (we'll use it to define the candidates for nns)
#---------------------------------

# Pool the pre-GS and post-GS contexts (pre-GS first) and hand them, together
# with the pre-trained embeddings, to get_local_vocab(); the result is passed
# as `candidates` to contrast_nns() below.
# NOTE(review): presumably this keeps only context words that have an
# embedding in `pre_trained` -- confirm against the conText documentation.
local_vocab <- c(contextpreGS$context, contextpostGS$context) %>%
  get_local_vocab(pre_trained)

#---------------------------------
# contrast_nns
#---------------------------------

# Seed fixed so bootstrap and permutation draws are reproducible.
set.seed(123L)
contrast_target <- contrast_nns(context1 = contextpostGS$context, 
                                context2 = contextpreGS$context, 
                                pre_trained, transform_matrix, transform = TRUE, 
                                bootstrap = TRUE, num_bootstraps = 20, 
                                permute = TRUE, num_permutations = 100, 
                                candidates = local_vocab, norm = "l2")

# first get pre-post nearest neighbors (output by the contrast_nns function);
# context1 is post-GS, so nns1 holds the post-GS neighbors, nns2 the pre-GS ones
nnspost <- contrast_target$nns1
nnspre <- contrast_target$nns2

# number of nearest neighbors kept per period
# (assumes nns1/nns2 are sorted by decreasing similarity -- TODO confirm)
N <- 100
# head() returns at most N terms and, unlike x[1:N], never pads with NA
# when fewer than N neighbors are available
top_pre <- head(nnspre$Term, N)
top_post <- head(nnspost$Term, N)

# subset to the union of top N nearest neighbors for each period
top_nns <- union(top_pre, top_post)

# identify which of these are shared
shared_nns <- intersect(top_pre, top_post)

# cutoff on the permutation p-value for a ratio to count as significant.
# NOTE(review): with num_permutations = 100 this presumably means no permuted
# ratio was as extreme as the observed one -- confirm against conText.
p_threshold <- 0.01

# subset nns_ratio (output by contrast_nns) to the union of the top nearest
# neighbors and label each term by the period(s) whose top N contain it;
# every remaining term is in top_pre or top_post, so `group` is never NA
nns_ratio <- contrast_target$nns_ratio %>%
  filter(Term %in% top_nns) %>%
  mutate(group = case_when(Term %in% shared_nns ~ 'shared',
                           Term %in% top_pre ~ 'Pre-GS',
                           Term %in% top_post ~ 'Post-GS'),
         significant = if_else(Empirical_Pvalue < p_threshold, 'yes', 'no'))

# order Terms by how far their ratio is from 1 (largest deviation first);
# tokenID is that rank and is used as the y position in the plot, and
# significant terms get a trailing "*" in their label.
# row_number() (unlike 1:nrow(.)) also works on an empty frame.
nns_ratio <- nns_ratio %>% 
  mutate(absdev = abs(1 - Estimate)) %>% 
  arrange(desc(absdev)) %>% 
  mutate(tokenID = row_number(),
         Term_Sig = if_else(significant == 'yes', paste0(Term, "*"), Term))

# keep only the significant ratios (shared terms included, despite the
# `_nsh` suffix); tokenIDs keep their ranks from the unfiltered table
nns_ratio_nsh <- nns_ratio %>%
  filter(significant == "yes")

## alternative plot

# Colours and shapes are keyed by group name. Unnamed manual values are
# assigned positionally to whichever groups are present, so a group with no
# significant terms would otherwise shift the colours of the others. The
# shapes are ggplot2's default first three (16, 17, 15) in the same order.
group_colors <- c("Post-GS" = "#1b733e", "Pre-GS" = "black", "shared" = "grey")
group_shapes <- c("Post-GS" = 16, "Pre-GS" = 17, "shared" = 15)

# One point per significant term at (ratio, rank), a segment from the
# no-change line (ratio = 1) to the point, and the term label beside it.
# `size` on the line geoms is kept for compatibility with ggplot2 < 3.4;
# newer versions prefer `linewidth` there.
pt <- ggplot(data = nns_ratio_nsh,
             aes(x = Estimate, y = tokenID, color = group, shape = group)) +
  geom_point(size = 3) +
  geom_segment(aes(x = 1, xend = Estimate, yend = tokenID), size = 1, alpha = .5) +
  geom_vline(xintercept = 1, colour = 'black', linetype = "dashed", size = 0.5) +
  # labels sit to the right of points with ratio > 1 and to the left otherwise
  geom_text(aes(label = Term_Sig, hjust = if_else(Estimate > 1, -0.2, 1.2)),
            vjust = 0.25, size = 5, show.legend = FALSE) +
  scale_color_manual(values = group_colors) +
  scale_shape_manual(values = group_shapes) +
  xlim(0,2.5) +
  # keep ranks 1-50, i.e. the 50 largest deviations from 1 among all top-N
  # terms (only the significant ones among them are drawn); rows outside
  # either axis limit are dropped with a "Removed rows" warning
  ylim(0,50) +
  ylab('') +
  labs(title = "Tweets") +
  xlab("cosine similarity ratio \n (post-GS/pre-GS)") +
  theme(panel.background = element_blank(),
        plot.title = element_text(size=18, hjust = 0.5, face="bold"),
        axis.text.x = element_text(size=16),
        axis.text.y = element_text(size=16),
        axis.title.y = element_text(size=16, 
                                    margin = margin(t = 0, r = 15, b = 0, l = 15)),
        axis.title.x = element_text(size=16, 
                                    margin = margin(t = 15, r = 0, b = 15, l = 0)),
        legend.text=element_text(size=16),
        legend.title=element_blank(),
        legend.key=element_blank(),
        legend.position = "bottom",
        legend.direction = "horizontal",
        legend.spacing.x = unit(0.25, 'cm'),
        plot.margin=unit(c(1,1,0,0),"cm"))

# the plot object itself is stored (not a rendered image)
save(pt, file = "data/output/contrast_ctweet.RData")